#define ARMA_DONT_PRINT_ERRORS

#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]

#include <algorithm>
#include <cmath>
#include <functional>
#include <random>

using namespace Rcpp;
using namespace arma;



// [[Rcpp::plugins(cpp11)]]
#include <omp.h>
// [[Rcpp::plugins(openmp)]]

// log(2*pi): per-dimension term of the Gaussian log-density normalizing
// constant, used by dmvnrm_arma.
const double log2pi = std::log(M_PI * 2.0);




double am_mult2(const arma::mat& rhs)
{
  arma::mat ip= rhs.t();
 return (ip(0));
}


// Draw `n` samples from a multivariate normal with mean `mu` and covariance
// `sigma`.
//   n     : number of draws
//   mu    : mean; expected to be a d x 1 column (d = sigma.n_cols)
//   sigma : d x d covariance; must admit a Cholesky factorization
//   rn    : source of independent standard-normal deviates. The default is
//           R's norm_rand, which is not safe to call from worker threads;
//           threaded callers should pass their own generator.
// Returns a d x n matrix whose columns are the draws.
arma::mat mvrnormArma(int n, arma::mat mu, arma::mat sigma,
                      std::function<double()> rn = norm_rand) {
  // n x d block of iid N(0, 1) deviates, filled element by element from rn.
  arma::mat std_draws(n, sigma.n_cols);
  std_draws.imbue(rn);

  // Each row of `mean_rows` is mu'; the upper Cholesky factor U satisfies
  // U'U = sigma, so std_draws * U has covariance sigma row-wise.
  const arma::mat mean_rows = arma::trans(arma::repmat(mu, 1, n));
  const arma::mat samples = mean_rows + std_draws * arma::chol(sigma);

  // Callers expect one draw per column.
  return arma::trans(samples);
}


// Multivariate-normal draws (see mvrnormArma) using R's norm_rand as the
// deviate source, i.e. the same generator mvrnormArma uses by default.
// NOTE(review): relies on R's RNG state being set up by the caller
// (e.g. Rcpp's RNGScope); not for use inside OpenMP worker threads.
arma::mat defaultRNG(int n, arma::mat mu, arma::mat sigma) {
  const std::function<double()> r_norm_rand(norm_rand);
  return mvrnormArma(n, mu, sigma, r_norm_rand);
}


// Multivariate-normal draws (see mvrnormArma) from a Mersenne Twister seeded
// with s[<current OpenMP thread id>], truncated to an integer.
// Outside a parallel region the thread id is 0, so s[0] is used.
// NOTE(review): s is indexed without a bounds check; it must hold at least
// one entry per thread of the enclosing team.
arma::mat serial(int n, const arma::mat mu, arma::mat sigma, const arma::vec& s) {
  const double seed_value = s[omp_get_thread_num()];
  std::mt19937 gen(static_cast<uint64_t>(seed_value));
  std::normal_distribution<double> std_normal(0, 1);

  auto draw = [&std_normal, &gen]() { return std_normal(gen); };
  return mvrnormArma(n, mu, sigma, draw);
}





// Multivariate normal density evaluated at each row of `x`.
//   x     : one observation per row (nobs x d)
//   mean  : 1 x d mean row vector
//   sigma : d x d covariance; must admit a Cholesky factorization
//   logd  : if true return log-densities, otherwise densities
// Returns a column vector with one entry per row of `x`.
arma::vec dmvnrm_arma(arma::mat x,
                      arma::rowvec mean,
                      arma::mat sigma,
                      bool logd = false) {
  const int nobs = x.n_rows;
  const int dim = x.n_cols;

  // inv_root = (R^{-1})' with sigma = R'R (R upper Cholesky factor).
  // inv_root * (x_i - mean)' whitens an observation, and the sum of
  // log(diag(inv_root)) equals -0.5 * log|sigma|.
  const arma::mat inv_root = arma::trans(arma::inv(arma::trimatu(arma::chol(sigma))));
  const double half_logdet_term = arma::sum(arma::log(inv_root.diag()));
  const double norm_const = -(static_cast<double>(dim) / 2.0) * log2pi;

  arma::vec dens(nobs);
  int row = 0;
  while (row < nobs) {
    const arma::vec white = inv_root * arma::trans(x.row(row) - mean);
    dens(row) = norm_const - 0.5 * arma::sum(white % white) + half_logdet_term;
    ++row;
  }

  if (!logd) {
    dens = arma::exp(dens);
  }
  return dens;
}







// Unwrap an Armadillo result to a scalar by returning its element at linear
// index 0 (callers pass 1x1 results such as column sums).
// Throws Armadillo's out-of-bounds error if `rhs` is empty.
double vm_convert(const arma::mat& rhs)
{
  const double leading = rhs(0);
  return leading;
}




// [[Rcpp::export()]]
Rcpp::List mcmc_logit_index (arma::mat y,
arma::mat X,
arma::mat I_party,
arma::mat I_country,
arma::mat I_outcome,
arma::mat Cov_Beta,
arma::mat beta2start,
arma::mat beta3start,
arma::mat reffect_partystart,
arma::mat reffect_countrystart,
arma::rowvec densitybeta,
arma::mat DiagWish, 
int mcmc = 100,
int burn=10,
int thin=10,
int chains=3
 ) {

 // Dimensions
int r= X.n_cols ;
int N=X.n_rows;
int n_country=I_country.n_cols;
int n_party=I_party.n_cols;


 // Current Containers

  arma::mat beta2 = beta2start ;
  arma::mat beta3 = beta3start ;
  arma::mat beta2new(r,1);
  beta2new.fill(0.0);
  arma::mat beta3new(r,1);
  beta3new.fill(0.0);

 arma::mat reffect_party = reffect_partystart ;
 arma::mat reffect_country = reffect_countrystart ;
 



 arma::mat etaold_beta(N, 3) ;
 etaold_beta.fill(0.0) ;
 arma::mat etanew_beta(N, 3) ;
 etanew_beta.fill(0.0) ;




 arma::mat etaold_country(N, 3) ;
 etaold_country.fill(0.0) ;
 arma::mat etanew_country(N, 3) ;
 etanew_country.fill(0.0) ;
 arma::mat lcountry(n_country, 1) ;
 lcountry.fill(0.0) ;


 arma::mat etaold_party(N, 3) ;
 etaold_party.fill(0.0) ;
 arma::mat etanew_party(N, 3) ;
 etanew_party.fill(0.0) ;
 arma::mat lparty(n_party, 1) ;
 lparty.fill(0.0) ;




arma::mat sigmabetadensity(r,r,arma::fill::eye);

arma::mat reffect_country_new(n_country,2);
arma::mat reffect_party_new(n_party,2);

 arma::mat mu0(1, 2) ;
 mu0.fill(0.0) ;

double var=1.0;
arma::vec varscountry(2) ;
varscountry.fill(var) ;
 arma::mat sigma2country(2, 2) ;
 sigma2country.fill(0.0) ;
 sigma2country.diag() = varscountry ;

arma::vec varsparty(2) ;
varsparty.fill(var) ;
 arma::mat sigma2party(2, 2) ;
 sigma2party.fill(0.0) ;
 sigma2party.diag() = varsparty ;


 int lastit= (mcmc-burn)/thin	;

// Trace Containers
  arma::cube tracebeta2(r, lastit,chains) ;
  arma::cube tracebeta3(r, lastit,chains) ;
 arma::cube tracesigma2country(4, lastit,chains) ;
 arma::cube tracereffect_country(n_country*2, lastit,chains) ;
 arma::cube tracesigma2party(4, lastit,chains) ;
 arma::cube tracereffect_party(n_party*2, lastit,chains) ;



arma::vec seeds = linspace(0, 2, chains);



 
#pragma omp parallel num_threads(chains)
  {


std::mt19937 engine(
static_cast<uint64_t>(seeds[omp_get_thread_num()] ) );
    std::uniform_real_distribution<double> uni(0.0, 1.0);
     std::normal_distribution<double> nor(0, 1);

  std::random_device rd;
    std::mt19937 gen(rd());




#pragma omp for 
for (int chain = 0; chain < chains; ++chain) {

 // START SAMPLING
for (int iter = 0 ; iter < mcmc ; ++iter) {


etaold_beta.col(0)=1/(1+exp(X*beta2+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))+
exp(X*beta3+I_party*reffect_party.col(1)+I_country*reffect_country.col(1)));
etaold_beta.col(1)=exp(X*beta2+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))%etaold_beta.col(0);
etaold_beta.col(2)=exp(X*beta3+I_party*reffect_party.col(1)+I_country*reffect_country.col(1))%etaold_beta.col(0);


beta2new=mvrnormArma(1, beta2, Cov_Beta, [&](){return nor(engine);});
beta3new=mvrnormArma(1, beta3, Cov_Beta, [&](){return nor(engine);});


etanew_beta.col(0)=1/(1+exp(X*beta2new+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))+
exp(X*beta3new+I_party*reffect_party.col(1)+I_country*reffect_country.col(1)));
etanew_beta.col(1)=exp(X*beta2new+I_party*reffect_party.col(0)+I_country*reffect_country.col(0))%etanew_beta.col(0);
etanew_beta.col(2)=exp(X*beta3new+I_party*reffect_party.col(1)+I_country*reffect_country.col(1))%etanew_beta.col(0);

double accept_beta=vm_convert((sum(I_outcome.col(0)%log(etanew_beta.col(0)),0)+sum(I_outcome.col(1)%log(etanew_beta.col(1)),0)+sum(I_outcome.col(2)%log(etanew_beta.col(2)),0))-
			(sum(I_outcome.col(0)%log(etaold_beta.col(0)),0)+sum(I_outcome.col(1)%log(etaold_beta.col(1)),0)+sum(I_outcome.col(2)%log(etaold_beta.col(2)),0)))+
(am_mult2(dmvnrm_arma(beta2new.t(),densitybeta, sigmabetadensity, true))-am_mult2(dmvnrm_arma(beta2.t(),densitybeta, sigmabetadensity, true)))+
(am_mult2(dmvnrm_arma(beta3new.t(),densitybeta, sigmabetadensity, true))-am_mult2(dmvnrm_arma(beta3.t(),densitybeta, sigmabetadensity, true)));

if(log(uni(engine))<accept_beta) {
for (int l=0; l<r; ++l) {
beta2(l)=beta2new(l);
beta3(l)=beta3new(l);
}
}


 // TRACE
if ((iter+1)> burn & (iter+1)%thin==0) {
int j=((iter)-burn)/thin;
tracebeta2.subcube(0, j, chain, (r-1), j, chain)=beta2;
tracebeta3.subcube(0, j, chain, (r-1), j, chain)=beta3;
tracereffect_country.subcube(0, j, chain, 2*n_country-1, j, chain)=vectorise(reffect_country);
tracesigma2country.subcube(0, j, chain, 3, j, chain)=vectorise(sigma2country);
tracereffect_party.subcube(0, j, chain, 2*n_party-1, j, chain)=vectorise(reffect_party);
tracesigma2party.subcube(0, j, chain, 3, j, chain)=vectorise(sigma2party);

} //end trace
}

}  // end iterations
} // end chains

// Returns
Rcpp::List ret;
 ret["beta2"] = tracebeta2 ;
 ret["beta3"] = tracebeta3 ;
ret["reffect_country"] = tracereffect_country;
ret["sigma2country"]=tracesigma2country;
ret["reffect_party"] = tracereffect_party;
ret["sigma2party"]=tracesigma2party;

 return(ret) ;
 

}

